package vip.aster.framework.mybatis.handler;

import cn.hutool.core.annotation.AnnotationUtil;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.collection.ConcurrentHashSet;
import cn.hutool.core.util.ClassUtil;
import cn.hutool.core.util.ObjectUtil;
import cn.hutool.core.util.StrUtil;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.Parenthesis;
import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import vip.aster.common.constant.enums.SuperAdminEnum;
import vip.aster.common.exception.BusinessException;
import vip.aster.framework.mybatis.annotation.DataColumn;
import vip.aster.framework.mybatis.annotation.DataPermission;
import vip.aster.framework.mybatis.enums.ColumnTypeEnum;
import vip.aster.framework.security.entity.SecurityUser;
import vip.aster.framework.security.entity.UserDetail;

import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;

/**
 * 数据权限过滤
 *
 * @author Aster
 * @since 2024/3/21 11:28
 */
@Slf4j
public class AsterDataPermissionHandler {
    /**
     * 方法或类(名称) 与 注解的映射关系缓存
     */
    private final Map<String, DataPermission> dataPermissionCacheMap = new ConcurrentHashMap<>();

    /**
     * 无效注解方法缓存用于快速返回
     */
    private final Set<String> invalidCacheSet = new ConcurrentHashSet<>();

    /**
     * 间隔符
     */
    private static final String EMPTY_STR = " ";

    public Expression getSqlSegment(Expression where, String mappedStatementId, boolean isSelect) {
        log.info("开始进行权限过滤,where: {},mappedStatementId: {}", where, mappedStatementId);
        DataColumn[] dataColumns = findAnnotation(mappedStatementId);
        if (dataColumns == null || dataColumns.length == 0) {
            invalidCacheSet.add(mappedStatementId);
            return where;
        }

        String dataFilterSql = buildDataFilter(dataColumns, isSelect);
        if (StrUtil.isBlank(dataFilterSql)) {
            return where;
        }
        try {
            log.info("sql：{}", dataFilterSql);
            Expression expression = CCJSqlParserUtil.parseCondExpression(dataFilterSql);
            // 数据权限使用单独的括号 防止与其他条件冲突
            Parenthesis parenthesis = new Parenthesis(expression);
            if (ObjectUtil.isNotNull(where)) {
                return new AndExpression(where, parenthesis);
            } else {
                return parenthesis;
            }
        } catch (JSQLParserException e) {
            throw new BusinessException("数据权限解析异常 => " + e.getMessage());
        }
    }

    /**
     * 获取注解参数
     *
     * @param mappedStatementId
     * @return
     */
    private DataColumn[] findAnnotation(String mappedStatementId) {
        StringBuilder sb = new StringBuilder(mappedStatementId);
        int index = sb.lastIndexOf(".");
        String clazzName = sb.substring(0, index);
        String methodName = sb.substring(index + 1, sb.length());
        Class<?> clazz = ClassUtil.loadClass(clazzName);
        List<Method> methods = Arrays.stream(ClassUtil.getDeclaredMethods(clazz))
                .filter(method -> method.getName().equals(methodName)).toList();
        DataPermission dataPermission;
        // 获取方法注解
        for (Method method : methods) {
            dataPermission = dataPermissionCacheMap.get(mappedStatementId);
            if (ObjectUtil.isNotNull(dataPermission)) {
                return dataPermission.value();
            }
            dataPermission = AnnotationUtil.getAnnotation(method, DataPermission.class);
            if (ObjectUtil.isNotNull(dataPermission)) {
                dataPermissionCacheMap.put(mappedStatementId, dataPermission);
                return dataPermission.value();
            }
        }
        dataPermission = dataPermissionCacheMap.get(clazz.getName());
        if (ObjectUtil.isNotNull(dataPermission)) {
            return dataPermission.value();
        }
        // 获取类注解
        dataPermission = AnnotationUtil.getAnnotation(clazz, DataPermission.class);
        if (ObjectUtil.isNotNull(dataPermission)) {
            dataPermissionCacheMap.put(clazz.getName(), dataPermission);
            return dataPermission.value();
        }
        return null;
    }


    /**
     * 构造数据过滤sql
     *
     * @return 过滤sql
     */
    private String buildDataFilter(DataColumn[] dataColumns, boolean isSelect) {
        // 当前用户
        UserDetail userDetail = SecurityUser.getUser();
        if (userDetail == null) {
            throw new BusinessException("请登录后再操作！");
        }

        // 如果是超级管理员，则不过滤数据
        if (SuperAdminEnum.YES.getCode().equals(userDetail.getSuperAdmin())) {
            return null;
        }

        // 数据权限-机构ID
        List<String> orgIds = userDetail.getDataScopeList();
        // 若数据权限为null,则不过滤数据
        if (orgIds == null) {
            return null;
        }
        // 更新或删除需满足所有条件
        String joinStr = isSelect ? " OR " : " AND ";

        // sql语句
        StringBuffer sb = new StringBuffer();
        for (DataColumn dataColumn : dataColumns) {

            // 表别名
            String tableAlias = StrUtil.isNotBlank(dataColumn.alias()) ? dataColumn.alias() + "." : "";
            // 表字段
            String tableField = dataColumn.value();
            // 字段类型
            String fieldType = dataColumn.type().getValue();

            if (ColumnTypeEnum.DEPT.getValue().equals(fieldType)) {
                if (CollUtil.isEmpty(orgIds)) {
                    if (!isSelect) {
                        return " 1 = 2 ";
                    } else {
                        continue;
                    }
                }
                sb.append(tableAlias).append(tableField).append(" IN ").append(convertListIdToString(orgIds));
                sb.append(joinStr);
            } else if (ColumnTypeEnum.USER.getValue().equals(fieldType)) {
                sb.append(tableAlias).append(tableField).append(" = ").append(userDetail.getId());
                sb.append(joinStr);
            }
        }

        // sql规范
        String sqlStr = sb.toString();
        if (sqlStr.endsWith(joinStr)) {
            sqlStr = sqlStr.substring(0, sqlStr.lastIndexOf(joinStr));
        }
        return StrUtil.isNotBlank(sqlStr) ? sqlStr : " 1=2 ";
    }

    /**
     * 是否为无效方法 无数据权限
     */
    public boolean isInvalid(String mappedStatementId) {
        return invalidCacheSet.contains(mappedStatementId) || !this.hasDataPermissionAnotation(mappedStatementId);
    }

    /**
     * 判断是否存在数据权限注解
     */
    private boolean hasDataPermissionAnotation(String mappedStatementId) {

        if (ObjectUtil.isNotNull(dataPermissionCacheMap.get(mappedStatementId))) {
            return true;
        }
        StringBuilder sb = new StringBuilder(mappedStatementId);
        int index = sb.lastIndexOf(".");
        String clazzName = sb.substring(0, index);
        String methodName = sb.substring(index + 1, sb.length());
        Class<?> clazz = ClassUtil.loadClass(clazzName);
        List<Method> methods = Arrays.stream(ClassUtil.getDeclaredMethods(clazz))
                .filter(method -> method.getName().equals(methodName)).collect(Collectors.toList());

        DataPermission dataPermission;
        // 获取类注解
        if (!AnnotationUtil.hasAnnotation(clazz, DataPermission.class)) {
            if (CollUtil.isEmpty(methods)) {
                invalidCacheSet.add(mappedStatementId);
                return false;
            }

            // 获取方法注解
            for (Method method : methods) {
                dataPermission = dataPermissionCacheMap.get(mappedStatementId);
                if (ObjectUtil.isNotNull(dataPermission)) {
                    return true;
                }
                if (AnnotationUtil.hasAnnotation(method, DataPermission.class)) {
                    dataPermission = AnnotationUtil.getAnnotation(method, DataPermission.class);
                    dataPermissionCacheMap.put(mappedStatementId, dataPermission);
                    return true;
                }
            }

            invalidCacheSet.add(mappedStatementId);
            return false;
        } else {
            dataPermission = AnnotationUtil.getAnnotation(clazz, DataPermission.class);
            dataPermissionCacheMap.put(clazz.getName(), dataPermission);
            return true;
        }
    }

    /**
     * 将List<String>转化为String
     *
     * @param ids 机构id
     * @return 机构字符串
     */
    private String convertListIdToString(List<String> ids) {
        if (CollUtil.isEmpty(ids)) {
            return null;
        }
        StringBuffer sb = new StringBuffer();
        sb.append(" (");
        ids.forEach(id -> {
            sb.append("'").append(id).append("',");
        });
        return sb.substring(0, sb.lastIndexOf(",")).concat(") ");
    }

}
